
from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Dict, List

# ---------------------------------------------------------------------------- #
# Topic vocabulary & mapping helpers                                           #
# ---------------------------------------------------------------------------- #

# Canonical topic key -> human-readable description of the topic.
# The keys are the short identifiers used for ground-truth topics in this
# script; the values are descriptive labels only.
# NOTE(review): nothing in the visible part of this file reads this mapping;
# presumably it is kept as reference documentation or used by other tooling.
TOPIC_VOCAB: Dict[str, str] = {
    "restaurant": "Restaurants, cafe, fast food",
    "chocolate": "Chocolate, cookies, candy, ice cream",
    "chips": "Chips, snacks, nuts, fruit, gum, cereal, yogurt, soups",
    "seasoning": "Seasoning, condiments, ketchup",
    "petfood": "Petfood",
    "alcohol": "Alcohol",
    "coffee": "Coffee, tea",
    "soda": "Soda, juice, milk, energy drinks, water",
    "cars": "Cars, automobile sales, parts, insurance, repair, gas, motor oil",
    "electronics": "Electronics (computers, laptops, tablets, cellphones, TVs)",
    "phone_tv_internet_providers": "Phone, TV and internet service providers",
    "financial": "Financial services (banks, credit cards, investment firms)",
    "education": "Education (universities, colleges, kindergarten, online degrees)",
    "security": "Security and safety services (anti-theft, safety courses)",
    "software": "Software (internet radio, streaming, job search website, grammar correction, travel planning)",
    "other_service": "Other services (dating, tax, legal, loan, religious, printing, catering)",
    "beauty": "Beauty products and cosmetics (deodorants, toothpaste, makeup, hair products, laser hair removal)",
    "healthcare": "Healthcare and medications (hospitals, health insurance, allergy, cold remedy, home tests, vitamins)",
    "clothing": "Clothing and accessories (jeans, shoes, eye glasses, handbags, watches, jewelry)",
    "baby": "Baby products (food, sippy cups, diapers)",
    "game": "Games and toys (including video and mobile games)",
    "cleaning": "Cleaning products (detergents, fabric softeners, soap, tissues, paper towels)",
    "home_improvement": "Home improvements and repairs (furniture, decoration, lawn care, plumbing)",
    "home_appliance": "Home appliances (coffee makers, dishwashers, cookware, vacuum cleaners, heaters, music players)",
    "travel": "Vacation and travel (airlines, cruises, theme parks, hotels, travel agents)",
    "media": "Media and arts (TV shows, movies, musicals, books, audio books)",
    "sports": "Sports equipment and activities",
    "shopping": "Shopping (department stores, drug stores, groceries)",
    "gambling": "Gambling (lotteries, casinos)",
    "environment": "Environment, nature, pollution, wildlife",
    "animal_right": "Animal rights, animal abuse",
    "human_right": "Human rights",
    "safety": "Safety, safe driving, fire safety",
    "smoking_alcohol_abuse": "Smoking, alcohol abuse",
    "domestic_violence": "Domestic violence",
    "self_esteem": "Self esteem, bullying, cyber bullying",
    "political": "Political candidates",
    "charities": "Charities",
}

# Position in this list defines the option index: the first entry is option 1,
# the second is option 2, and so on. The sequence **must** therefore line up
# with the 1-based option indices stored in annotation.json; reordering or
# inserting entries here silently re-labels the ground truth.
TOPIC_ORDER: List[str] = [
    "restaurant",
    "chocolate",
    "chips",
    "seasoning",
    "petfood",
    "alcohol",
    "coffee",
    "soda",
    "cars",
    "electronics",
    "phone_tv_internet_providers",
    "financial",
    "education",
    "security",
    "software",
    "other_service",
    "beauty",
    "healthcare",
    "clothing",
    "baby",
    "game",
    "cleaning",
    "home_improvement",
    "home_appliance",
    "travel",
    "media",
    "sports",
    "shopping",
    "gambling",
    "environment",
    "animal_right",
    "human_right",
    "safety",
    "smoking_alcohol_abuse",
    "domestic_violence",
    "self_esteem",
    "political",
    "charities",
]

# 1-based option index -> canonical topic key, derived from TOPIC_ORDER.
OPTION_TO_TOPIC: Dict[int, str] = dict(enumerate(TOPIC_ORDER, start=1))


# ---------------------------------------------------------------------------- #
# Utility functions                                                             #
# ---------------------------------------------------------------------------- #


def load_ground_truth(path: Path) -> Dict[str, str]:
    """Load ground-truth mapping *video_id* -> *topic string*.

    The annotation file is expected to be a JSON object whose values are
    1-based option indices (integers, or strings/integral floats that
    represent integers). Each index is translated to a canonical topic key
    through ``OPTION_TO_TOPIC``.

    Parameters
    ----------
    path : Path
        Path to *annotation.json* file.

    Returns
    -------
    Dict[str, str]
        Dictionary mapping video id to canonical topic key.

    Raises
    ------
    ValueError
        If the JSON root is not an object, or an option value cannot be
        interpreted as an integer (including booleans and non-integral
        floats, which would otherwise be silently coerced).
    KeyError
        If an option index has no entry in ``OPTION_TO_TOPIC``.
    """
    with path.open("r", encoding="utf-8") as f:
        raw = json.load(f)

    if not isinstance(raw, dict):
        raise ValueError(
            "Ground truth must be a JSON object mapping video id to option index, "
            f"got {type(raw).__name__}."
        )

    gt: Dict[str, str] = {}
    for vid, opt in raw.items():
        invalid_msg = f"Invalid option value {opt!r} for video {vid!r} in ground truth."

        # bool is a subclass of int, so JSON true/false would otherwise be
        # accepted as options 1/0; int() would also silently truncate 2.9 to 2.
        if isinstance(opt, bool) or (isinstance(opt, float) and not opt.is_integer()):
            raise ValueError(invalid_msg)

        if isinstance(opt, int):
            opt_int = opt
        else:
            # Sometimes numbers are encoded as strings; try to cast.
            try:
                opt_int = int(opt)
            except (TypeError, ValueError) as exc:
                raise ValueError(invalid_msg) from exc

        topic = OPTION_TO_TOPIC.get(opt_int)
        if topic is None:
            raise KeyError(f"No topic mapping found for option {opt_int} (video id: {vid}).")
        gt[vid] = topic
    return gt


def load_predictions(pred_dir: Path) -> List[tuple[str, str]]:
    """Load predictions from **all** *.json* files inside *pred_dir*.

    Each file may hold either a single JSON object or a list of objects.
    For every object, ``"final_topic"`` is the predicted topic and
    ``"video_id"`` identifies the video; when ``"video_id"`` is missing or
    empty, the file name without its extension is used instead.

    Files are visited in sorted path order so the returned list is
    reproducible across runs and filesystems. Files that cannot be read,
    decoded or parsed, records that are not JSON objects, and records with
    no ``"final_topic"`` are skipped with a ``[WARN]`` message on stdout.

    Parameters
    ----------
    pred_dir : Path
        Directory that is searched (non-recursively) for ``*.json`` files.

    Returns
    -------
    List[tuple[str, str]]
        *(video_id, topic)* pairs **including duplicates**, so that accuracy
        can be computed both with and without deduplication.
    """
    records_all: List[tuple[str, str]] = []

    # glob() yields entries in arbitrary order; sort for deterministic output.
    for json_path in sorted(pred_dir.glob("*.json")):
        try:
            with json_path.open("r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both json.JSONDecodeError and
            # UnicodeDecodeError; OSError covers unreadable files. One bad
            # file must not abort the whole load.
            print(f"[WARN] Failed to parse {json_path.name}: {e}")
            continue

        # Some prediction files are a single object, others are a list of objects.
        recs_in_file = data if isinstance(data, list) else [data]

        for rec in recs_in_file:
            if not isinstance(rec, dict):
                print(f"[WARN] Unexpected record type in {json_path.name}: {type(rec).__name__}; skipping.")
                continue

            # Fall back to the file stem when the record carries no usable id.
            video_id = rec.get("video_id") or json_path.stem

            topic = rec.get("final_topic")
            if not topic:
                print(f"[WARN] Missing 'final_topic' for video {video_id} in {json_path.name}; skipping.")
                continue

            records_all.append((video_id, topic))

    return records_all


# ------------------------------------------------------------------------- #
# Accuracy helpers
# ------------------------------------------------------------------------- #


def compute_accuracy_records(records: List[tuple[str, str]], gt: Dict[str, str]) -> tuple[int, int]:
    """Return ``(correct, total)`` for raw prediction records.

    Every record is counted, duplicates included. ``total`` is simply the
    number of records, so a record whose video id is absent from *gt* still
    contributes to the denominator (and is scored by comparing its topic to
    ``gt.get(vid)``, i.e. to ``None``).
    """
    hits = sum(1 for video_id, predicted in records if gt.get(video_id) == predicted)
    return hits, len(records)


def compute_accuracy_unique(pred_unique: Dict[str, str], gt: Dict[str, str]) -> tuple[int, int]:
    """Return ``(correct, total)`` for a de-duplicated ``video_id -> topic`` map.

    Only video ids that have a (non-``None``) ground-truth entry are scored;
    predictions for unknown ids are left out of both counts.
    """
    # Pair each prediction with its reference topic, dropping unknown ids.
    scored = [
        (predicted, expected)
        for predicted, expected in (
            (topic, gt.get(video_id)) for video_id, topic in pred_unique.items()
        )
        if expected is not None
    ]
    hits = sum(1 for predicted, expected in scored if predicted == expected)
    return hits, len(scored)


def _load_predictions_file(path: Path) -> List[tuple[str, str]]:
    """Load *(video_id, topic)* pairs from a single prediction JSON file.

    The file may contain one JSON object or a list of objects. Records that
    are not objects or that lack a ``"final_topic"`` are dropped silently;
    a missing/empty ``"video_id"`` falls back to the file stem. A file that
    cannot be read, decoded or parsed yields an empty list after a warning.
    """
    try:
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[WARN] Could not read {path.name}: {e}")
        return []

    recs = data if isinstance(data, list) else [data]
    pairs: List[tuple[str, str]] = []
    for rec in recs:
        if not isinstance(rec, dict):
            continue
        topic = rec.get("final_topic")
        if not topic:
            continue
        pairs.append((rec.get("video_id") or path.stem, topic))
    return pairs


def main() -> None:
    """Command-line entry point: report per-file topic accuracy.

    For every ``*.json`` file in ``--pred_dir`` (in sorted order) two
    accuracies are computed against the ground truth in ``--annot_file``:

    * ``dup``    - over all records of the file, duplicates included;
    * ``unique`` - over one prediction per video id (the last record for an
      id wins), ignoring ids that are absent from the ground truth.

    One line per file is printed to stdout and written to ``--output``.
    An accuracy with an empty denominator is reported as ``0.0``.

    Raises
    ------
    NotADirectoryError
        If ``--pred_dir`` is not an existing directory.
    FileNotFoundError
        If ``--annot_file`` is not an existing file.
    """
    parser = argparse.ArgumentParser(description="Evaluate topic prediction accuracy.")
    parser.add_argument("--pred_dir", type=str, required=True, help="Directory containing prediction JSON files.")
    parser.add_argument("--annot_file", type=str, required=True, help="Path to topic_annotation.json ground-truth file.")
    # %(default)s keeps the help text in sync with the actual default value.
    parser.add_argument("--output", type=str, default="np_metric.txt", help="File to write accuracy to (default: %(default)s)")
    args = parser.parse_args()

    pred_dir = Path(args.pred_dir)
    annot_path = Path(args.annot_file)
    out_path = Path(args.output)

    if not pred_dir.is_dir():
        raise NotADirectoryError(f"Prediction directory not found: {pred_dir}")
    if not annot_path.is_file():
        raise FileNotFoundError(f"Annotation file not found: {annot_path}")

    # Load ground truth once
    gt_map = load_ground_truth(annot_path)

    with out_path.open("w", encoding="utf-8") as f_out:
        for json_path in sorted(pred_dir.glob("*.json")):
            pred_records = _load_predictions_file(json_path)

            # duplicates accuracy (denominator = every record in the file)
            c_dup, t_dup = compute_accuracy_records(pred_records, gt_map)
            acc_dup = (c_dup / t_dup) if t_dup else 0.0

            # unique accuracy: dict() keeps the last topic seen for each id
            unique_map: Dict[str, str] = dict(pred_records)
            c_u, t_u = compute_accuracy_unique(unique_map, gt_map)
            acc_u = (c_u / t_u) if t_u else 0.0

            line = f"{json_path.name}: dup {acc_dup:.4f} | unique {acc_u:.4f}"
            print(line)
            f_out.write(line + "\n")


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
